#####
# Script to generate figures 3 and 6.
#####

library(ggplot2)
library(cowplot)
library(viridis)
library(RColorBrewer)
library(patchwork)
library(tidyverse)
library(data.table)

library(pracma)
library(nleqslv)
library(rootSolve)
library(fields)
library(chebpol)
library(scales)

library(doParallel)
library(Rcpp)

library(simFnsEqns)

source("equations.r")
source("simulation_functions.r")
source("planner_simulation.r")

# Model parameters ----------------------------------------------------------
# NOTE(review): the globals `F` and `T` defined below shadow R's FALSE/TRUE
# shortcuts for everything that looks them up in the global environment
# (including functions defined by the source() calls above). They are kept
# because the sourced helpers may read them directly; any code that writes
# `T`/`F` meaning TRUE/FALSE will silently receive these numbers instead.

# Economic primitives.
p <- 1
discount <- 0.99
r <- ((1 - discount) / (discount))  # rate implied by the discount factor
fe_eqm <- 0.06
F <- p / (r + fe_eqm)

# Dynamics coefficients.
d <- 0.6  # alternative value used elsewhere: d <- 0.05
m <- 2
aSS <- 0.05
aSD <- 0.005
aDD <- 0.09
bSS <- 10
bSD <- 5.375
bDD <- 1.3

# Bounds / horizon and the tag used in input/output file names.
Xbar <- 10
T <- 150
label <- "benchmark"

# Bundle everything into the single-row parameter table the solvers consume.
params <- data.frame(
  p = p, F = F, discount = discount, r = r, fe_eqm = fe_eqm,
  d = d, m = m,
  aSS = aSS, aSD = aSD, aDD = aDD,
  bSS = bSS, bSD = bSD, bDD = bDD,
  T = T, Xbar = Xbar
)

# Number of parallel workers handed to both the policy solver and the basin
# finder. Kept in one place so it can be lowered on smaller machines.
n_cores <- 70

## Value-function grids: a small grid (20 points per argument, S in [0, 1],
## D in [0, 10]) and a large grid (200 points per argument, same bounds).
## NOTE(review): "points per argument" assumes make_sd_grid(S_lo, S_hi, n_S,
## D_lo, D_hi, n_D) takes per-axis counts — confirm in simulation_functions.r.
SD_grid_s <- make_sd_grid(0, 1, 20, 0, 10, 20)
SD_grid <- make_sd_grid(0, 1, 200, 0, 10, 200)
# Distinct S and D values of the large grid.
grid_list <- list(S = unique(SD_grid$S), D = unique(SD_grid$D))

# Solve for the optimal policy. The small grid is passed as the initial-guess
# grid and the large grid as the VFI grid (per the argument names).
benchmark_policy <- solve_opt_policy(
  params = params,
  small_guess_grid = SD_grid_s,
  larger_vfi_grid = SD_grid,
  grid_list = grid_list,
  ncores = n_cores,
  label = label
)

# Calculate phase diagrams and stable basins.
# NOTE(review): the return value is not used below; the basin CSVs read next
# are presumably written by this call — confirm before removing it.
basins_results <- find_stable_basins(
  params = params,
  optimal_policy = benchmark_policy,
  SD_grid = SD_grid,
  grid_list = grid_list,
  ncores = n_cores,
  label = label
)

# Read one basin table ("../data/<prefix>_<label>_basin.csv") keeping only
# the columns the plotting code needs, in this fixed order.
read_basin <- function(prefix, label,
                       cols = c("satellites", "debris", "X",
                                "path_convergence")) {
  path <- file.path("..", "data", paste0(prefix, "_", label, "_basin.csv"))
  fread(path)[, cols, with = FALSE]
}

oa_dfrm <- read_basin("oa", label)
opt_dfrm <- read_basin("opt", label)

## Calibration scale factors passed to the plotting code.
## NOTE(review): these presumably map model-unit satellites (s) and debris (d)
## to real-world counts — confirm against the calibration notes.
s_scale_factor <- 116649.8 / 1.237508
d_scale_factor <- 367372.3 / 1.170672

# Make figures with policy and phase+basin. Blue for satellite nullcline, red
# for debris nullcline, black full circle for stable steady state, open circle
# for unstable steady state.
results <- make_policy_phase_plots(
  oa_dfrm, opt_dfrm, s_scale_factor, d_scale_factor
)

# Compose a policy panel and a phase panel side by side (tagged A/B, with a
# single shared legend on the right) and write it as a 300-dpi PNG.
#
# policy_plot, phase_plot: ggplot objects for the left and right panels.
# filename: file name written inside `image_dir`.
# image_dir: output directory (default "../images").
# Returns the composed patchwork invisibly.
save_policy_phase_figure <- function(policy_plot, phase_plot, filename,
                                     image_dir = file.path("..", "images")) {
  # `+ plot_layout()` sets the top-level layout; `& theme()` (lower precedence
  # than `+`) is then applied to every panel.
  figure <- (policy_plot | phase_plot) +
    plot_annotation(tag_levels = "A") +
    plot_layout(guides = "collect") &
    theme(legend.position = "right")

  ggsave(
    filename = file.path(image_dir, filename),
    plot = figure,
    width = 16 * 3 / 2,        # 24 in
    height = 9 * 3 / 2 * 3 / 4, # 10.125 in
    units = "in",
    dpi = 300
  )
  invisible(figure)
}

save_policy_phase_figure(
  results$oa_policy, results$oa_phase,
  "figure-3--oa-policy-phase.png"
)
save_policy_phase_figure(
  results$opt_policy, results$opt_phase,
  "figure-6--opt-policy-phase.png"
)
